from functools import partial
import os

import numpy as np
import pandas as pd
import scipy.stats as st
import fire

from plot import plot_se_comp, plot_intersectional, plot_cat_comps

# Seed NumPy's global generator at import time. The query-level bootstrap
# resamples in this module are drawn with pandas ``sample`` without an explicit
# random_state, so this makes runs reproducible.
np.random.seed(42)

# Per-rank weights, looked up as ``ATT_WTS[img_rank]`` / ``CLICK_WTS[img_rank]``
# in ``prep_data`` and normalised there to sum to 1.
# NOTE(review): presumably relative attention and click-through rates by result
# position taken from prior work; the source is not recorded here. Confirm.
ATT_WTS = [
    1.08, 1.56, 2.72, 1.13, 0.77,
    0.52, 0.93, 1.35, 0.69, 0.42,
    0.47, 0.45, 0.51, 0.5, 0.37,
]

CLICK_WTS = [
    8.77, 11.45, 16.37, 10.46, 4.19,
    2.08, 7.51, 11.53, 4.06, 0.54,
    3.78, 2.30, 3.29, 3.22, 2.17,
]


def add_labeler_wt(df):
    """Give every row of ``df`` the same weight, ``1 / len(df)``.

    The weight is stored in column ``labeler_wt``. ``df`` is modified in place
    and also returned, so the function can be used with ``groupby().apply``.
    An empty frame raises ``ZeroDivisionError``.
    """
    share = 1 / len(df)
    df["labeler_wt"] = share
    return df


def add_person_wt(df):
    """Weight every row by ``1 / <number of distinct person_id values in df>``.

    The weight is stored in column ``person_wt``. ``df`` is modified in place
    and returned, so the function can be used with ``groupby().apply``.
    """
    people = df.person_id.unique()
    n_people = len(people)
    # every distinct person maps to the same share; empty input -> empty lookup
    lookup = dict.fromkeys(people, 1 / n_people) if n_people else {}
    df["person_wt"] = df.person_id.map(lambda pid: lookup[pid])
    return df


def add_skin_tone_labeler_wt(df):
    """Reweight rows so that labeler skin tones count as if uniformly distributed.

    Each distinct ``worker_skin_tone`` value gets weight
    ``(1 / n_distinct_values) / observed_share``. That weight is written to
    column ``st_labeler_wt``. ``df`` is modified in place and returned.
    """
    observed = df.worker_skin_tone.value_counts(normalize=True)
    uniform = 1 / df.worker_skin_tone.nunique()
    weight_by_tone = (uniform / observed).to_dict()
    df["st_labeler_wt"] = df.worker_skin_tone.map(lambda tone: weight_by_tone[tone])
    return df


def _rank_weight(ranks, raw_wts):
    """Map every rank in ``ranks`` to ``raw_wts[rank]``, with ``raw_wts`` rescaled to sum to 1.

    ``ranks`` is used directly as a Python list index.
    NOTE(review): this treats ranks as 0-based positions into the weight list;
    confirm that is how ``img_rank`` is encoded.
    """
    total = sum(raw_wts)
    normed = [w / total for w in raw_wts]
    return ranks.apply(lambda r: normed[r])


def prep_data(df, add_st_wt=False):
    """Add the indicator, weight and age-bin columns used by the analyses.

    Added columns: ``feminine_presenting`` (bool), ``labeler_wt``,
    ``person_wt``, ``st_labeler_wt``, ``age_bin``, ``att_wt``, ``click_wt``
    and ``total_wt``. The ``"telecommunication"`` category is also renamed to
    ``"telecomms"``.

    Args:
        df: raw per-label data. Its ``feminine_presenting`` column is written in
            place; the grouped steps below then return new frames.
        add_st_wt: if true, ``total_wt`` is additionally multiplied by
            ``st_labeler_wt``.

    Returns:
        The augmented frame. Rows with a missing value in any grouping key are
        dropped by the grouped ``apply`` calls (pandas' default ``dropna=True``).
    """
    df["feminine_presenting"] = df.gender == "feminine_presenting"
    # Rows sharing (sengine, qry_id, img_rank, img_id, person_id) split a total
    # weight of 1, presumably one row per labeler of that person.
    df = df.groupby(
        ["sengine", "qry_id", "img_rank", "img_id", "person_id"], group_keys=False
    ).apply(add_labeler_wt)
    # every distinct person within an image gets an equal share
    df = df.groupby("img_id", group_keys=False).apply(add_person_wt)
    df = add_skin_tone_labeler_wt(df)

    # Bin ages by decade into (0, 10], (10, 20], ..., (80, 90]. pd.cut bins are
    # right-closed, which the "age<=40" summary relies on. include_lowest=True
    # keeps age 0 in "0-10"; otherwise the first bin is open at 0 and such rows
    # silently get a NaN bin.
    # NOTE(review): ages above 90 still fall outside every bin (NaN).
    bins = [i * 10 for i in range(10)]
    labels = [f"{lo}-{hi}" for lo, hi in zip(bins, bins[1:])]
    df["age_bin"] = pd.cut(df.age, bins=bins, labels=labels, include_lowest=True)

    # rank-position weights, each normalised to sum to 1 over its list
    df["att_wt"] = _rank_weight(df.img_rank, ATT_WTS)
    df["click_wt"] = _rank_weight(df.img_rank, CLICK_WTS)

    # NOTE(review): att_wt is computed but is not part of total_wt; confirm that
    # click weighting alone is intended here.
    df["total_wt"] = df.labeler_wt * df.person_wt * df.click_wt
    if add_st_wt:
        df["total_wt"] *= df.st_labeler_wt

    df["category"] = df.category.replace({"telecommunication": "telecomms"})
    return df


def worker_label_correlations(df):
    """Print how the face labels relate to the labelers' own attributes.

    Prints the Pearson correlation of ``skin_tone`` with ``worker_skin_tone``
    and of ``age`` with ``worker_age``. Rows missing either value of a pair are
    dropped first. It also prints ``scipy.stats.contingency.association`` for
    the crosstab of ``gender`` vs ``worker_gender``.

    Returns:
        dict with keys ``skin_tone_corr``, ``age_corr``, ``gender_ct`` and
        ``gender_assoc`` holding the printed results. Callers that ignore the
        return value are unaffected.
    """
    tone_pair = df[["skin_tone", "worker_skin_tone"]].dropna()
    st_corr = st.pearsonr(tone_pair.skin_tone, tone_pair.worker_skin_tone)
    print(f"skin tone correlation: {st_corr}")

    age_pair = df[["age", "worker_age"]].dropna()
    age_corr = st.pearsonr(age_pair.age, age_pair.worker_age)
    print(f"age correlation: {age_corr}")

    # Map the labelers' gender vocabulary onto the face-label vocabulary so the
    # crosstab's rows and columns line up. ``replace`` maps values;
    # ``Series.rename`` with a dict would only relabel the index.
    worker_gender = df.worker_gender.replace(
        {
            "male": "masculine_presenting",
            "female": "feminine_presenting",
            "non_binary": "non_binary_presenting",
            "none": "unsure",
        }
    )
    gender_ct = pd.crosstab(
        df.gender.reset_index(drop=True),
        worker_gender.reset_index(drop=True),
    )
    gender_corr = st.contingency.association(gender_ct)
    print(f"gender association: {gender_corr}")

    return {
        "skin_tone_corr": st_corr,
        "age_corr": age_corr,
        "gender_ct": gender_ct,
        "gender_assoc": gender_corr,
    }


def resample_qrys(df, random_state=None):
    """Draw a query-level bootstrap resample of ``df``.

    The distinct ``qry_id`` values are sampled with replacement, with as many
    draws as there are distinct queries. Each drawn query keeps all of its rows,
    and the number of times it was drawn is attached as column ``bs_wt``.
    Queries that were never drawn are dropped by the inner merge.

    Args:
        df: frame with a ``qry_id`` column.
        random_state: forwarded to ``Series.sample``. With ``None`` (the
            default) pandas falls back to NumPy's global generator, as before.

    Returns:
        ``df`` rows of the drawn queries, with the extra ``bs_wt`` column.
    """
    qry_ids = df.qry_id.drop_duplicates()
    draws = qry_ids.sample(frac=1, replace=True, random_state=random_state)
    bs_cts = draws.value_counts().reset_index()
    # positional rename works for both pandas<2 and pandas>=2 column naming
    bs_cts.columns = ["qry_id", "bs_wt"]
    # internal sanity check: every draw is counted exactly once
    assert bs_cts.bs_wt.sum() == len(qry_ids)
    return pd.merge(df, bs_cts, on="qry_id", validate="m:1")


def wt_sum(df, wt_col="total_wt"):
    """Return each category's share of the summed weight, per demographic attribute.

    For each of ``gender``, ``skin_tone`` and ``age_bin``, the rows are grouped
    by that attribute and ``wt_col`` is summed per group. The group sums are
    divided by their own total. The three resulting Series are stacked into one
    Series indexed by the attribute values.
    """

    def shares(attr):
        totals = df.groupby(attr)[wt_col].sum()
        return totals / totals.sum()

    return pd.concat([shares(attr) for attr in ("gender", "skin_tone", "age_bin")])


def get_mean_st(d, wt_col="total_wt"):
    """Return the ``wt_col``-weighted mean ``skin_tone`` of the rows in ``d``.

    Rows whose ``skin_tone`` is missing are ignored in both the numerator and
    the weight total. This is consistent with the groupby-based category
    shares, which exclude missing labels. Previously a single missing label
    turned the whole mean into NaN.
    """
    known = d.skin_tone.notna()
    tones = d.skin_tone[known]
    wts = d[wt_col][known]
    return (tones * wts).sum() / wts.sum()


def extra_stats(d, orig, wt_col="total_wt"):
    """Append derived summary columns to the per-search-engine table ``d``.

    ``d`` is modified in place and returned. It must contain the share columns
    ``feminine_presenting``, ``masculine_presenting``, skin tones ``1``-``3``
    and age bins ``"0-10"``..``"30-40"``. It must be indexed by ``sengine`` with
    ``"google"`` and ``"bing"`` rows. ``orig`` is the row-level data the shares
    were computed from; it is used for the weighted mean skin tone per engine.
    """
    mean_tone = partial(get_mean_st, wt_col=wt_col)

    d["female-male"] = d["feminine_presenting"] - d["masculine_presenting"]
    d["skin_tone<=3"] = d[[1, 2, 3]].sum(axis=1)
    d["mean_skin_tone"] = orig.groupby("sengine").apply(mean_tone)
    engine_means = d["mean_skin_tone"]
    # one scalar difference, broadcast to every row
    d["skin_tone_google-bing"] = engine_means.loc["google"] - engine_means.loc["bing"]
    young_bins = ["0-10", "10-20", "20-30", "30-40"]
    d["age<=40"] = d[young_bins].sum(axis=1)
    return d


def _bootstrap_replicate(data, grp_key, sum_fn):
    """Compute the per-group summary plus ``extra_stats`` on one query-level resample of ``data``."""
    resampled = resample_qrys(data).reset_index(drop=True)
    # a query drawn k times contributes k times its original weight
    resampled["wt"] = resampled.total_wt * resampled.bs_wt
    summary = resampled.groupby(grp_key).apply(partial(sum_fn, wt_col="wt"))
    return extra_stats(summary, resampled, wt_col="wt")


def bootstrap(data, grp_key, sum_fn, n_bs=1000, alpha=0.05):
    """Bootstrap per-group summaries by resampling whole queries with replacement.

    Args:
        data: row-level frame with ``qry_id``, ``total_wt`` and whatever
            ``sum_fn`` / ``extra_stats`` read.
        grp_key: column to summarise by. NOTE(review): ``extra_stats`` always
            groups by ``"sengine"``, so this presumably must be ``"sengine"``.
        sum_fn: per-group summary taking ``(frame, wt_col=...)``.
        n_bs: number of bootstrap replicates.
        alpha: two-sided level for the percentile interval.

    Returns:
        ``pd.concat`` of the point estimates on the full data (one row per
        group, group key as a column), followed by the bootstrap quantiles. For
        each group there is a row at ``alpha / 2`` and then one at
        ``1 - alpha / 2``. The quantile labels are dropped, so those rows are
        identified by position only.
    """
    replicates = pd.concat(
        [_bootstrap_replicate(data, grp_key, sum_fn) for _ in range(n_bs)]
    )

    bounds = [alpha / 2, 1 - alpha / 2]
    ci = replicates.groupby(grp_key).apply(lambda grp: grp.quantile(q=bounds))

    point = extra_stats(data.groupby(grp_key).apply(sum_fn), data).reset_index()
    ci = ci.reset_index()
    ci = ci.drop(columns=[c for c in ci.columns if str(c).startswith("level")])
    return pd.concat([point, ci])


def bootstrap_intersectional(data, cats, n_bs=1000, alpha=0.05):
    """Bootstrap the weight share of each combination of the ``cats`` columns.

    Whole queries are resampled with replacement ``n_bs`` times. For each
    replicate, the share of total weight falling in every ``cats`` combination
    is recorded.

    Returns:
        ``pd.concat`` of the full-data shares (indexed by ``cats``) followed by
        the ``alpha / 2`` and ``1 - alpha / 2`` bootstrap quantiles (indexed by
        ``cats`` plus a quantile level).
    """

    def shares(frame, wt_col):
        return frame.groupby(cats)[wt_col].sum() / frame[wt_col].sum()

    replicates = []
    for _ in range(n_bs):
        resampled = resample_qrys(data).reset_index(drop=True)
        resampled["wt"] = resampled.total_wt * resampled.bs_wt
        replicates.append(shares(resampled, "wt"))

    bounds = [alpha / 2, 1 - alpha / 2]
    ci = pd.concat(replicates).groupby(cats).apply(lambda s: s.quantile(q=bounds))
    point = shares(data, "total_wt")

    return pd.concat([point, ci])


def category_summary(data, fp_out="tables/category_summary.tex"):
    cat_sum = data.groupby("category").apply(
        lambda df: pd.Series(
            {
                "n_queries": df.qry_id.nunique(),
                "faces_per_image": df.drop_duplicates(subset="img_id").n_faces.mean(),
            }
        )
    )
    cat_sum = cat_sum.reset_index().sort_values(by="n_queries", ascending=False)
    cat_sum["n_queries"] = cat_sum["n_queries"].astype(int)
    cat_sum["faces_per_image"] = cat_sum["faces_per_image"].round(2)
    os.makedirs("tables", exist_ok=True)
    cat_sum.to_latex(fp_out, index=False)


def qry_average(df):
    """Return the ``total_wt``-weighted means for one query's rows.

    The result is a Series with weighted means of ``feminine_presenting`` (as
    0/1), ``skin_tone`` and ``age``, plus the first row's ``category``
    (presumably constant within a query; not checked). ``df`` is not modified.
    Previously the caller's ``feminine_presenting`` column was overwritten
    with ints.
    """
    wts = df.total_wt
    denom = sum(wts)

    def wmean(values):
        return sum(values * wts) / denom

    return pd.Series(
        {
            "feminine_presenting": wmean(df.feminine_presenting.astype(int)),
            "skin_tone": wmean(df.skin_tone),
            "age": wmean(df.age),
            "category": df.category.iloc[0],
        }
    )


def get_category_stats(df, col, baseline):
    """Summarise ``df[col]``: mean, 95% t-interval and one-sample t-test p-value.

    Returns:
        Series with ``est`` (mean), ``lb`` / ``ub`` (95% Student-t interval
        around the mean, using the standard error) and ``pval`` (two-sided
        ``ttest_1samp`` against ``baseline``).
    """
    values = df[col]
    center = np.mean(values)
    lower, upper = st.t.interval(0.95, len(values) - 1, center, st.sem(values))
    test = st.ttest_1samp(values, baseline)
    return pd.Series({"est": center, "lb": lower, "ub": upper, "pval": test.pvalue})


def pval_to_stars(pval):
    """Return significance stars for a p-value.

    ``"***"`` for p < 0.001, ``"**"`` for p < 0.01, ``"*"`` for p < 0.05,
    otherwise ``""``. NaN also gives ``""`` because every comparison with NaN
    is false.
    """
    for cutoff, stars in ((0.001, "***"), (0.01, "**"), (0.05, "*")):
        if pval < cutoff:
            return stars
    return ""


def get_stats(data, col, baseline):
    """Per-category estimates of ``col`` with FDR-adjusted significance.

    ``get_category_stats`` is applied to each ``category`` group of ``data``.
    The p-values are then adjusted with Benjamini-Hochberg (``adj_pval``).

    Returns:
        Frame sorted by ``est`` ascending, with a ``category`` column and an
        ``adj_category`` label (category name plus significance stars).
    """
    stats = data.groupby("category").apply(
        partial(get_category_stats, col=col, baseline=baseline)
    )
    # A category with a single query gets a NaN p-value (no variance
    # estimate), and false_discovery_control rejects NaN input. Adjust only
    # the defined p-values; the others stay NaN and get no stars.
    pvals = stats.pval
    defined = pvals.notna()
    adj = pd.Series(np.nan, index=stats.index)
    if defined.any():
        adj[defined] = st.false_discovery_control(pvals[defined], method="bh")
    stats["adj_pval"] = adj
    stats = stats.reset_index().sort_values(by="est")
    stats["adj_category"] = stats.apply(
        lambda row: row.category + pval_to_stars(row.adj_pval), axis=1
    )
    return stats


def main(fp_input="img_data.csv", n_boot=1000, add_st_wt=False):
    """Run the full representation analysis on ``fp_input``.

    Steps:
    1. Load and prepare the data, then write the category summary table.
    2. Bootstrap representation per search engine, print the key contrasts
       and plot them.
    3. Unless ``add_st_wt`` is set, bootstrap the skin tone x gender and
       age x gender intersections and plot them.
    4. Average per query and compare each category against fixed baselines.

    Args:
        fp_input: path of the input CSV.
        n_boot: number of bootstrap replicates.
        add_st_wt: also weight by the labeler skin-tone reweighting.
    """
    data = prep_data(pd.read_csv(fp_input), add_st_wt=add_st_wt)
    category_summary(data)

    # 1. representation across search engines
    se_stats = bootstrap(data, "sengine", wt_sum, n_bs=n_boot)
    se_stats.index.name = "sengine"
    summary_cols = [
        "female-male",
        "skin_tone<=3",
        "mean_skin_tone",
        "skin_tone_google-bing",
        "age<=40",
    ]
    print(se_stats[summary_cols])
    os.makedirs("plots", exist_ok=True)
    plot_se_comp(se_stats, add_st_wt)

    # 2. intersectional representation; both bootstraps run before any
    #    plotting so the global RNG is consumed in the same order as before
    if not add_st_wt:
        plot_paths = {
            "skin_tone": "plots/results_st_gender.pdf",
            "age_bin": "plots/results_age_gender.pdf",
        }
        intersectional = {
            attr: bootstrap_intersectional(data, [attr, "gender"], n_bs=n_boot)
            for attr in plot_paths
        }
        for attr, path in plot_paths.items():
            plot_intersectional(intersectional[attr], attr, fp_out=path)

    # 3. representation across categories
    # NOTE(review): baselines are presumably population reference values
    # (share feminine, midpoint of the skin-tone scale, median age); their
    # source is not recorded here.
    baselines = {"feminine_presenting": 0.504, "skin_tone": 5.5, "age": 38.9}
    per_query = data.groupby("qry_id").apply(qry_average)
    cat_stats = {
        measure: get_stats(per_query, measure, base)
        for measure, base in baselines.items()
    }
    plot_cat_comps(cat_stats)


if __name__ == "__main__":
    # Expose ``main``'s keyword arguments as command-line flags via python-fire
    # (e.g. ``--fp_input``, ``--n_boot``, ``--add_st_wt``).
    fire.Fire(main)
